{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "u_OIwDav0A4W",
      "metadata": {
        "id": "u_OIwDav0A4W"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "T_oWMJLg0fLk",
      "metadata": {
        "id": "T_oWMJLg0fLk"
      },
      "source": [
        "# Quick start with Model Garden - TxGemma\n",
        "\n",
        "<table><tbody><tr>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fgoogle-gemini%2Fgemma-cookbook%2Fmain%2FTxGemma%2F[TxGemma]Quickstart_with_Model_Garden.ipynb\">\n",
        "      <img alt=\"Google Cloud Colab Enterprise logo\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" width=\"32px\"><br> Run in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DQuickstart_with_Model_Garden.ipynb\">\n",
        "      <img alt=\"GitHub logo\" src=\"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\" width=\"32px\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</tr></tbody></table>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "JsEU-DK7DJcv",
      "metadata": {
        "id": "JsEU-DK7DJcv"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This notebook demonstrates how to use TxGemma in Vertex AI to generate predictions from therapeutics-related text prompts. It covers two methods for getting predictions:\n",
        "\n",
        "* **Online predictions** are synchronous requests that are made to the endpoint deployed from Model Garden and are served with low latency. Online predictions are useful if the responses are being used in production. The cost for online prediction is based on the time a virtual machine spends waiting in an active state (an endpoint with a deployed model) to handle prediction requests.\n",
        "\n",
        "* **Batch predictions** are asynchronous requests that are run on a set number of prompts specified in a single job. They are made directly to an uploaded model and do not use an endpoint deployed from Model Garden. Batch predictions are useful if you want to generate responses for a large number of prompts for use in training and don't require low latency. The cost for batch prediction is based on the time a virtual machine spends running your prediction job.\n",
        "\n",
        "Vertex AI makes it easy to serve your model and make it accessible to the world. Learn more about [Vertex AI](https://cloud.google.com/vertex-ai/docs/start/introduction-unified-platform).\n",
        "\n",
        "### Objectives\n",
        "\n",
        "- Deploy TxGemma to a Vertex AI Endpoint and get online predictions.\n",
        "- Upload TxGemma to Vertex AI Model Registry and get batch predictions.\n",
        "\n",
        "### Costs\n",
        "\n",
        "This tutorial uses billable components of Google Cloud:\n",
        "\n",
        "* Vertex AI\n",
        "* Cloud Storage\n",
        "\n",
        "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Fe_iHV1RDA3C",
      "metadata": {
        "id": "Fe_iHV1RDA3C"
      },
      "source": [
        "## Before you begin"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "-nRhKI2hEpzn",
      "metadata": {
        "cellView": "form",
        "id": "-nRhKI2hEpzn"
      },
      "outputs": [],
      "source": [
        "# @title Import packages and define common functions\n",
        "\n",
        "import datetime\n",
        "import importlib\n",
        "import io\n",
        "import json\n",
        "import os\n",
        "import uuid\n",
        "\n",
        "import google.auth\n",
        "from google.cloud import aiplatform\n",
        "from google.cloud import storage\n",
        "import openai\n",
        "\n",
        "if not os.path.isdir(\"vertex-ai-samples\"):\n",
        "  ! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git\n",
        "\n",
        "common_util = importlib.import_module(\n",
        "    \"vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util\"\n",
        ")\n",
        "\n",
        "models, endpoints = {}, {}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2_aYhzoEDMCf",
      "metadata": {
        "cellView": "form",
        "id": "2_aYhzoEDMCf"
      },
      "outputs": [],
      "source": [
        "# @title Set up Google Cloud environment\n",
        "\n",
        "# @markdown #### Prerequisites\n",
        "\n",
        "# @markdown 1. Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/modify-project) for your project.\n",
        "\n",
        "# @markdown 2. Make sure that either the Compute Engine API is enabled or that you have the [Service Usage Admin](https://cloud.google.com/iam/docs/understanding-roles#serviceusage.serviceUsageAdmin) (`roles/serviceusage.serviceUsageAdmin`) role to enable the API.\n",
        "\n",
        "# @markdown This section sets the default Google Cloud project and region, enables the Compute Engine API (if not already enabled), and initializes the Vertex AI API.\n",
        "\n",
        "# Get the default project ID.\n",
        "PROJECT_ID = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n",
        "\n",
        "# Get the default region for launching jobs.\n",
        "REGION = os.environ[\"GOOGLE_CLOUD_REGION\"]\n",
        "\n",
        "# Enable the Compute Engine API, if not already.\n",
        "print(\"Enabling Compute Engine API.\")\n",
        "! gcloud services enable compute.googleapis.com\n",
        "\n",
        "# Initialize Vertex AI API.\n",
        "print(\"Initializing Vertex AI API.\")\n",
        "aiplatform.init(project=PROJECT_ID, location=REGION)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8oEhN-_6Y6BN",
      "metadata": {
        "id": "8oEhN-_6Y6BN"
      },
      "source": [
        "## Format prompts for therapeutic tasks\n",
        "\n",
        "The following sections in this notebook demonstrate prompting TxGemma for therapeutic development tasks from [Therapeutics Data Commons](https://tdcommons.ai/) (TDC).\n",
        "\n",
        "For these predictive tasks, prompts should be formatted according to the TDC structure, including:\n",
        "\n",
        "- **Instruction:** Briefly describes the task.\n",
        "\n",
        "- **Context:** Provides 2-3 sentences of relevant biochemical background, derived from TDC descriptions and literature.\n",
        "\n",
        "- **Question:** Queries a specific therapeutic property, incorporating textual representations of therapeutics and/or targets as inputs.\n",
        "\n",
        "  - Inputs can include SMILES strings, amino acid sequences, nucleotide sequences, and natural language text.\n",
        "\n",
        "  - Optional few-shot examples can be provided."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Cix7nel8CcqK",
      "metadata": {
        "cellView": "form",
        "id": "Cix7nel8CcqK"
      },
      "outputs": [],
      "source": [
        "# @title Load prompt template\n",
        "\n",
        "# @markdown First, load a JSON file that contains the prompt format for various TDC tasks.\n",
        "\n",
        "! gcloud storage cp gs://healthai-us/txgemma/templates/tdc_prompts.json tdc_prompts.json\n",
        "\n",
        "with open(\"tdc_prompts.json\", \"r\") as f:\n",
        "    tdc_prompts_json = json.load(f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "n6VW-6UVCgjz",
      "metadata": {
        "cellView": "form",
        "id": "n6VW-6UVCgjz"
      },
      "outputs": [],
      "source": [
        "# @title Prepare sample prompt\n",
        "\n",
        "# @markdown Construct a prompt using the template and an input drug SMILES string from the [BBB (Blood-Brain Barrier), Martins et al.](https://tdcommons.ai/single_pred_tasks/adme#bbb-blood-brain-barrier-martins-et-al) dataset. This prompt will be used for generating predictions in the next sections.\n",
        "\n",
        "# @markdown **Note:** The prompt should not be modified from the template except for replacing the inputs.\n",
        "\n",
        "# Example task and input\n",
        "task_name = \"BBB_Martins\"\n",
        "input_type = \"{Drug SMILES}\"\n",
        "drug_smiles = \"CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\"\n",
        "\n",
        "TDC_PROMPT = tdc_prompts_json[task_name].replace(input_type, drug_smiles)\n",
        "print(\"Formatted prompt:\\n\")\n",
        "print(TDC_PROMPT)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "DpGxMQxndGcG",
      "metadata": {
        "id": "DpGxMQxndGcG"
      },
      "source": [
        "## Get online predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "55wdOLrEK8kg",
      "metadata": {
        "cellView": "form",
        "id": "55wdOLrEK8kg"
      },
      "outputs": [],
      "source": [
        "# @title Import deployed model\n",
        "\n",
        "# @markdown To get [online predictions](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions), you will need a TxGemma [Vertex AI Endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment) that has been deployed from Model Garden. If you have not already done so, go to the [TxGemma model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/txgemma) in Model Garden and click \"Deploy\" to deploy the model.\n",
        "\n",
        "# @markdown This section gets the Vertex AI Endpoint resource that you deployed from Model Garden to use for online predictions.\n",
        "\n",
        "# @markdown Fill in the endpoint ID and region below. You can find your deployed endpoint on the [Vertex AI online prediction page](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).\n",
        "\n",
        "ENDPOINT_ID = \"\"  # @param {type: \"string\", placeholder:\"e.g. 123456789\"}\n",
        "ENDPOINT_REGION = \"\"  # @param {type: \"string\", placeholder:\"e.g. us-central1\"}\n",
        "\n",
        "endpoints[\"endpoint\"] = aiplatform.Endpoint(\n",
        "    endpoint_name=ENDPOINT_ID,\n",
        "    project=PROJECT_ID,\n",
        "    location=ENDPOINT_REGION,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "QmkhDkBicAZV",
      "metadata": {
        "id": "QmkhDkBicAZV"
      },
      "source": [
        "### Explore predictive capabilities\n",
        "\n",
        "TxGemma models are designed to process and understand information related to various therapeutic modalities and targets, including small molecules, proteins, nucleic acids, diseases, and cell lines, and can generate predictions on a broad set of therapeutic development tasks.\n",
        "\n",
        "You can send [online prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) requests to the endpoint with formatted text prompts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "trU5YDBEHwn-",
      "metadata": {
        "cellView": "form",
        "id": "trU5YDBEHwn-"
      },
      "outputs": [],
      "source": [
        "# @title #### Run inference on a therapeutic task\n",
        "\n",
        "# @markdown This section demonstrates prompting TxGemma for a predictive task from TDC.\n",
        "\n",
        "# Prepare request instance using sample prompt for a TDC task\n",
        "instances = [\n",
        "    {\n",
        "        \"prompt\": TDC_PROMPT,\n",
        "        \"max_tokens\": 8,\n",
        "        \"temperature\": 0\n",
        "    },\n",
        "]\n",
        "\n",
        "response = endpoints[\"endpoint\"].predict(instances=instances)\n",
        "predictions = response.predictions\n",
        "\n",
        "print(predictions[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "V4WlXpx2fGN4",
      "metadata": {
        "id": "V4WlXpx2fGN4"
      },
      "source": [
        "### Explore conversational capabilities with TxGemma-Chat\n",
        "\n",
        "TxGemma features conversational models that add reasoning and explainability to predictions and can be used in multi-turn interactions. Their conversational ability comes at the expense of some predictive power.\n",
        "\n",
        "You can send [online prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) requests to the endpoint using the [OpenAI SDK](https://github.com/openai/openai-python).\n",
        "\n",
        "**For this section, make sure that you are using a TxGemma-Chat model variant.**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Zawh4-BZbcMR",
      "metadata": {
        "cellView": "form",
        "id": "Zawh4-BZbcMR"
      },
      "outputs": [],
      "source": [
        "# @title #### Ask questions in a multi-turn conversation\n",
        "\n",
        "# @markdown This section demonstrates prompting TxGemma for conversational use.\n",
        "\n",
        "# @markdown In this example, first prompt the model to answer a question regarding a predictive task using the TDC format. Then, ask a follow-up question requesting the model to provide its reasoning for the predicted answer.\n",
        "\n",
        "import google.auth\n",
        "import google.auth.transport.requests  # Needed explicitly for Request() below\n",
        "from IPython.display import display, Markdown\n",
        "import openai\n",
        "\n",
        "creds, project = google.auth.default()\n",
        "auth_req = google.auth.transport.requests.Request()\n",
        "creds.refresh(auth_req)\n",
        "\n",
        "ENDPOINT_RESOURCE_URL = (\n",
        "    f\"https://{ENDPOINT_REGION}-aiplatform.googleapis.com/v1beta1/\"\n",
        "    f\"projects/{PROJECT_ID}/\"\n",
        "    f\"locations/{ENDPOINT_REGION}/\"\n",
        "    f\"endpoints/{ENDPOINT_ID}\"\n",
        ")\n",
        "\n",
        "client = openai.OpenAI(base_url=ENDPOINT_RESOURCE_URL, api_key=creds.token)\n",
        "\n",
        "questions = [\n",
        "    TDC_PROMPT,  # Initial question is a predictive task from TDC\n",
        "    \"Explain your reasoning based on the molecule structure.\"\n",
        "]\n",
        "\n",
        "messages = []\n",
        "\n",
        "display(Markdown(\"\\n\\n---\\n\\n\"))\n",
        "for question in questions:\n",
        "    display(Markdown(f\"**User:**\\n\\n{question}\\n\\n---\\n\\n\"))\n",
        "    messages.append(\n",
        "        { \"role\": \"user\", \"content\": question },\n",
        "    )\n",
        "    model_response = client.chat.completions.create(\n",
        "        model=\"\",\n",
        "        messages=messages,\n",
        "        temperature=0,\n",
        "        max_tokens=512,\n",
        "    )\n",
        "    response_content = model_response.choices[0].message.content\n",
        "    display(Markdown(f\"**TxGemma:**\\n\\n{response_content}\\n\\n---\\n\\n\"))\n",
        "    messages.append(\n",
        "        { \"role\": \"assistant\", \"content\": response_content },\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4GjsXPnATk2B",
      "metadata": {
        "id": "4GjsXPnATk2B"
      },
      "source": [
        "## Get batch predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "XeAKjvsUuW8S",
      "metadata": {
        "cellView": "form",
        "id": "XeAKjvsUuW8S"
      },
      "outputs": [],
      "source": [
        "# @title Get access to TxGemma\n",
        "\n",
        "# @markdown The prediction container directly loads the model from Hugging Face Hub.\n",
        "\n",
        "# @markdown To enable access to the TxGemma models, you must provide a Hugging Face User Access Token. You can follow the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to create a **read** access token and specify it in the `HF_TOKEN` field below.\n",
        "\n",
        "HF_TOKEN = \"\"  # @param {type:\"string\", placeholder:\"Hugging Face User Access Token\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9A6-24lJVM8Y",
      "metadata": {
        "cellView": "form",
        "id": "9A6-24lJVM8Y"
      },
      "outputs": [],
      "source": [
        "# @title Upload model to Vertex AI Model Registry\n",
        "\n",
        "# @markdown To get [batch predictions](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions), you must first upload the prebuilt TxGemma model to [Vertex AI Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction). Batch prediction requests are made directly to a model in Model Registry without deploying to an endpoint.\n",
        "\n",
        "MODEL_VARIANT = \"2b-predict\"  # @param [\"2b-predict\", \"9b-chat\", \"9b-predict\", \"27b-chat\", \"27b-predict\"]\n",
        "\n",
        "MODEL_ID = f\"txgemma-{MODEL_VARIANT}\"\n",
        "\n",
        "# The pre-built serving docker image.\n",
        "SERVE_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20250114_0916_RC00_maas\"\n",
        "\n",
        "# This notebook uses Nvidia L4 GPUs for demonstration.\n",
        "# See https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#batch_prediction\n",
        "# for details on configuring compute for Vertex AI batch predictions.\n",
        "if \"2b\" in MODEL_ID:\n",
        "    MACHINE_TYPE = \"g2-standard-12\"\n",
        "    ACCELERATOR_TYPE = \"NVIDIA_L4\"\n",
        "    ACCELERATOR_COUNT = 1\n",
        "elif \"9b\" in MODEL_ID:\n",
        "    MACHINE_TYPE = \"g2-standard-24\"\n",
        "    ACCELERATOR_TYPE = \"NVIDIA_L4\"\n",
        "    ACCELERATOR_COUNT = 2\n",
        "else:\n",
        "    MACHINE_TYPE = \"g2-standard-48\"\n",
        "    ACCELERATOR_TYPE = \"NVIDIA_L4\"\n",
        "    ACCELERATOR_COUNT = 4\n",
        "\n",
        "\n",
        "def upload_model(\n",
        "    model_name: str,\n",
        "    model_id: str,\n",
        "    tensor_parallel_size: int = ACCELERATOR_COUNT,\n",
        "    gpu_memory_utilization: float = 0.95,\n",
        ") -> aiplatform.Model:\n",
        "    vllm_args = [\n",
        "        \"python\",\n",
        "        \"-m\",\n",
        "        \"vllm.entrypoints.api_server\",\n",
        "        \"--host=0.0.0.0\",\n",
        "        \"--port=8080\",\n",
        "        f\"--model={model_id}\",\n",
        "        f\"--tensor-parallel-size={tensor_parallel_size}\",\n",
        "        \"--swap-space=16\",\n",
        "        f\"--gpu-memory-utilization={gpu_memory_utilization}\",\n",
        "        \"--enable-chunked-prefill\",\n",
        "        \"--disable-log-stats\",\n",
        "    ]\n",
        "\n",
        "    env_vars = {\n",
        "        \"MODEL_ID\": model_id,\n",
        "        \"DEPLOY_SOURCE\": \"notebook\",\n",
        "        \"HF_TOKEN\": HF_TOKEN,\n",
        "    }\n",
        "\n",
        "    model = aiplatform.Model.upload(\n",
        "        display_name=model_name,\n",
        "        serving_container_image_uri=SERVE_DOCKER_URI,\n",
        "        serving_container_args=vllm_args,\n",
        "        serving_container_ports=[8080],\n",
        "        serving_container_predict_route=\"/generate\",\n",
        "        serving_container_health_route=\"/ping\",\n",
        "        serving_container_environment_variables=env_vars,\n",
        "    )\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "models[\"model\"] = upload_model(\n",
        "    model_name=common_util.get_job_name_with_datetime(prefix=MODEL_ID),\n",
        "    model_id=f\"google/{MODEL_ID}\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "_T9UDGV5TbrR",
      "metadata": {
        "cellView": "form",
        "id": "_T9UDGV5TbrR"
      },
      "outputs": [],
      "source": [
        "# @title Set up Google Cloud resources\n",
        "\n",
        "# @markdown This section sets up a [Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing batch prediction inputs and outputs and gets the [Compute Engine default service account](https://cloud.google.com/compute/docs/access/service-accounts#default_service_account) which will be used to run the batch prediction jobs.\n",
        "\n",
        "# @markdown 1. Make sure that you have the following required roles:\n",
        "# @markdown - [Storage Admin](https://cloud.google.com/iam/docs/understanding-roles#storage.admin) (`roles/storage.admin`) to create and use Cloud Storage buckets\n",
        "# @markdown - [Service Account User](https://cloud.google.com/iam/docs/understanding-roles#iam.serviceAccountUser) (`roles/iam.serviceAccountUser`) on either the project or the Compute Engine default service account\n",
        "\n",
        "# @markdown 2. Set up a Cloud Storage bucket.\n",
        "# @markdown - A new bucket will automatically be created for you.\n",
        "# @markdown - [Optional] To use an existing bucket, specify the `gs://` bucket URI. The bucket must be located in the same region in which the notebook was launched. Note that a multi-region bucket (e.g. \"us\") is not considered a match for a single region (e.g. \"us-central1\") covered by the multi-region range.\n",
        "\n",
        "BUCKET_URI = \"\"  # @param {type:\"string\", placeholder:\"[Optional] Cloud Storage bucket URI\"}\n",
        "\n",
        "# Cloud Storage bucket for storing batch prediction artifacts.\n",
        "# A unique bucket will be created for the purpose of this notebook. If you\n",
        "# prefer using your own GCS bucket, change the value of BUCKET_URI above.\n",
        "if BUCKET_URI is None or BUCKET_URI.strip() == \"\":\n",
        "    now = datetime.datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
        "    BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}\"\n",
        "    BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
        "    ! gcloud storage buckets create --location {REGION} {BUCKET_URI}\n",
        "else:\n",
        "    assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n",
        "    BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
        "    shell_output = ! gcloud storage buckets describe {BUCKET_NAME} | grep \"location:\" | sed \"s/location://\"\n",
        "    bucket_region = shell_output[0].strip().lower()\n",
        "    if bucket_region != REGION:\n",
        "        raise ValueError(\n",
        "            f\"Bucket region {bucket_region} is different from notebook region {REGION}\"\n",
        "        )\n",
        "print(f\"Using this Cloud Storage Bucket: {BUCKET_URI}\")\n",
        "\n",
        "# Service account used for running the prediction container.\n",
        "# Gets the Compute Engine default service account. If you prefer using your own\n",
        "# custom service account, change the value of SERVICE_ACCOUNT below.\n",
        "shell_output = ! gcloud projects describe $PROJECT_ID\n",
        "project_number = shell_output[-1].split(\":\")[1].strip().replace(\"'\", \"\")\n",
        "SERVICE_ACCOUNT = f\"{project_number}-compute@developer.gserviceaccount.com\"\n",
        "print(\"Using this service account:\", SERVICE_ACCOUNT)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "72ZyqSb0pBvP",
      "metadata": {
        "id": "72ZyqSb0pBvP"
      },
      "source": [
        "### Predict\n",
        "\n",
        "You can send [batch prediction requests](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#request_a_batch_prediction) to the model using a [JSON Lines](https://jsonlines.org/) file to specify a list of input instances with text prompts to generate predictions. For more details on configuring batch prediction jobs, see how to [format your input data](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#input_data_requirements) and [choose compute settings](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#choose_machine_type_and_replica_count)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "-5BC3KlAJXtj",
      "metadata": {
        "cellView": "form",
        "id": "-5BC3KlAJXtj"
      },
      "outputs": [],
      "source": [
        "# @title #### Generate predictions in batch from text prompts\n",
        "\n",
        "# @markdown This section shows an example of using TxGemma to generate predictions in batch from text prompts.\n",
        "\n",
        "batch_predict_instances = [\n",
        "    {\"prompt\": TDC_PROMPT, \"max_tokens\": 8, \"temperature\": 0},\n",
        "    {\"prompt\": TDC_PROMPT, \"max_tokens\": 8, \"temperature\": 0},\n",
        "]\n",
        "\n",
        "# Write instances to JSON Lines file\n",
        "os.makedirs(\"batch_predict_input\", exist_ok=True)\n",
        "instances_filename = \"gcs_instances.jsonl\"\n",
        "with open(f\"batch_predict_input/{instances_filename}\", \"w\") as f:\n",
        "    for line in batch_predict_instances:\n",
        "        json_str = json.dumps(line)\n",
        "        f.write(json_str)\n",
        "        f.write(\"\\n\")\n",
        "\n",
        "# Copy the file to Cloud Storage\n",
        "batch_predict_prefix = f\"batch-predict-{MODEL_ID}\"\n",
        "! gcloud storage cp ./batch_predict_input/{instances_filename} {BUCKET_URI}/{batch_predict_prefix}/input/{instances_filename}\n",
        "\n",
        "batch_predict_job_name = common_util.get_job_name_with_datetime(prefix=f\"batch-predict-{MODEL_ID}\")\n",
        "\n",
        "batch_predict_job = models[\"model\"].batch_predict(\n",
        "    job_display_name=batch_predict_job_name,\n",
        "    gcs_source=os.path.join(BUCKET_URI, batch_predict_prefix, f\"input/{instances_filename}\"),\n",
        "    gcs_destination_prefix=os.path.join(BUCKET_URI, batch_predict_prefix, \"output\"),\n",
        "    machine_type=MACHINE_TYPE,\n",
        "    accelerator_type=ACCELERATOR_TYPE,\n",
        "    accelerator_count=ACCELERATOR_COUNT,\n",
        "    service_account=SERVICE_ACCOUNT,\n",
        ")\n",
        "\n",
        "batch_predict_job.wait()\n",
        "\n",
        "print(batch_predict_job.display_name)\n",
        "print(batch_predict_job.resource_name)\n",
        "print(batch_predict_job.state)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "NhAN2phRSlL0",
      "metadata": {
        "cellView": "form",
        "id": "NhAN2phRSlL0"
      },
      "outputs": [],
      "source": [
        "# @title #### Get prediction results\n",
        "\n",
        "# @markdown This section shows an example of [retrieving batch prediction results](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#retrieve_batch_prediction_results) from the JSON Lines file(s) in the output Cloud Storage location.\n",
        "\n",
        "\n",
        "def download_gcs_files_as_json(gcs_files_prefix):\n",
        "    \"\"\"Download specified files from Cloud Storage and convert content to JSON.\"\"\"\n",
        "    lines = []\n",
        "    client = storage.Client()\n",
        "    bucket = storage.bucket.Bucket.from_string(BUCKET_NAME, client)\n",
        "    blobs = bucket.list_blobs(prefix=gcs_files_prefix)\n",
        "    for blob in blobs:\n",
        "        with blob.open(\"r\") as f:\n",
        "            for line in f:\n",
        "                lines.append(json.loads(line))\n",
        "    return lines\n",
        "\n",
        "\n",
        "batch_predict_output_dir = batch_predict_job.output_info.gcs_output_directory\n",
        "batch_predict_output_files_prefix = os.path.join(\n",
        "    batch_predict_output_dir.replace(f\"{BUCKET_NAME}/\", \"\"),\n",
        "    \"prediction.results\"\n",
        ")\n",
        "batch_predict_results = download_gcs_files_as_json(\n",
        "    gcs_files_prefix=batch_predict_output_files_prefix\n",
        ")\n",
        "\n",
        "# Display first two batch prediction results\n",
        "for line in batch_predict_results[:2]:\n",
        "    prediction = line[\"prediction\"]\n",
        "    print(prediction)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ZFRiJ2o_LhoE",
      "metadata": {
        "id": "ZFRiJ2o_LhoE"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Explore the other [notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "rQMUmDYr6O_O",
      "metadata": {
        "id": "rQMUmDYr6O_O"
      },
      "source": [
        "## Clean up resources"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "iOSY_r6mHYui",
      "metadata": {
        "cellView": "form",
        "id": "iOSY_r6mHYui"
      },
      "outputs": [],
      "source": [
        "# @markdown  Delete the models and endpoints used in this notebook to release their resources\n",
        "# @markdown  and avoid incurring ongoing charges.\n",
        "\n",
        "# Undeploy model and delete endpoint.\n",
        "for endpoint in endpoints.values():\n",
        "    endpoint.delete(force=True)\n",
        "\n",
        "# Delete models.\n",
        "for model in models.values():\n",
        "    model.delete()\n",
        "\n",
        "delete_bucket = False  # @param {type:\"boolean\"}\n",
        "if delete_bucket:\n",
        "    ! gcloud storage rm --recursive $BUCKET_NAME"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "[TxGemma]Quickstart_with_Model_Garden.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
